"""
Score network for the cube and sphere datasets.
"""

# pylint: disable=missing-function-docstring
# pylint: disable=missing-class-docstring

from datetime import datetime

import numpy as np
import torch
import torch.nn as nn

from diffusion_bandit.diffusion import DiffusionProcess


class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps.

    Each scalar input ``t`` is mapped to
    ``[sin(2*pi*w*t), cos(2*pi*w*t)]``, where ``w`` is a fixed (non-trainable)
    vector of ``embed_dim // 2`` frequencies drawn as ``randn * scale``.

    Args:
        embed_dim: Requested embedding size. The output has
            ``2 * (embed_dim // 2)`` features, so an odd value yields
            ``embed_dim - 1`` features.
        scale: Multiplier applied to the standard-normal frequencies.
    """

    def __init__(self, embed_dim: int, scale: float = 30.0):
        super().__init__()
        # Frozen random frequencies. Kept as an nn.Parameter named "weights"
        # so the state_dict layout is unchanged.
        self.weights = nn.Parameter(
            torch.randn(embed_dim // 2) * scale, requires_grad=False
        )

    def forward(self, x_data: torch.Tensor) -> torch.Tensor:
        """Embed ``x_data`` given as ``(..., 1)`` or as a flat ``(batch,)``.

        Returns:
            Tensor with ``2 * (embed_dim // 2)`` features on the last axis.
        """
        if x_data.dim() == 1:
            # A flat (batch,) vector would broadcast element-wise against the
            # (1, embed_dim // 2) frequency row instead of forming an outer
            # product; give it an explicit trailing feature axis.
            x_data = x_data.unsqueeze(-1)
        x_proj = x_data * self.weights[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


class LowRankLinear(nn.Module):
    """Bias-free linear map factored through a ``rank``-wide bottleneck.

    Computes ``V(U(x))`` with ``U: in_features -> rank`` and
    ``V: rank -> out_features`` (both without bias), so the effective weight
    matrix has rank at most ``rank``.
    """

    def __init__(self, in_features, out_features, rank):
        super().__init__()
        # Attribute names U / V define the state_dict keys; keep them.
        self.U = nn.Linear(in_features, rank, bias=False)
        self.V = nn.Linear(rank, out_features, bias=False)

    def forward(self, x):
        compressed = self.U(x)
        expanded = self.V(compressed)
        return expanded


class ResidualBlock(nn.Module):
    """Pre-normalised, constant-width residual MLP block.

    Computes ``input + drop(gelu(LN2(W2 · drop(gelu(W1 · LN1(input))))))``.

    Args:
        dim: Feature width of both input and output.
        dropout: Dropout probability applied after each activation.
    """

    def __init__(self, dim: int, dropout: float):
        super().__init__()
        # Submodule names and creation order are kept as-is (state_dict keys
        # and parameter-initialisation order stay the same).
        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.GELU()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        def _activate_and_drop(tensor: torch.Tensor) -> torch.Tensor:
            return self.dropout(self.activation(tensor))

        hidden = _activate_and_drop(self.linear1(self.layer_norm1(input_tensor)))
        hidden = _activate_and_drop(self.layer_norm2(self.linear2(hidden)))
        return hidden + input_tensor


class ResidualScore(nn.Module):
    """Residual-MLP score network for ``d_ext``-dimensional inputs.

    ``x`` passes through a two-layer input MLP; ``t`` passes through a
    Gaussian Fourier embedding followed by GELU. The two ``hidden_size``-wide
    feature vectors are concatenated, fused by ``joint_embedding``, refined by
    ``num_layers`` residual blocks and mapped back to ``d_ext`` by
    ``output_layer``. The result is divided by
    ``diffusion_process.marginal_prob_std(t)``.

    Args:
        d_ext: Input/output feature dimension.
        hidden_size: Width of the hidden layers and of the time embedding.
        num_layers: Number of residual blocks.
        dropout: Dropout probability used inside each residual block.
        diffusion_process: Object whose ``marginal_prob_std`` rescales the
            network output.
        **kwargs: Accepted but ignored (presumably so a shared config with
            extra keys can be passed through — confirm with callers).
    """

    def __init__(
        self,
        d_ext: int,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        diffusion_process: DiffusionProcess,
        **kwargs,
    ):
        super().__init__()
        width = hidden_size
        wide = 2 * hidden_size

        # Submodules are created in the original order so random
        # initialisation and state_dict keys are unchanged.
        self.diffusion_process = diffusion_process
        self.time_embedding = GaussianFourierProjection(embed_dim=width)
        self.act = nn.GELU()

        self.input_layer = nn.Sequential(
            nn.Linear(d_ext, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Linear(width, width),
            nn.LayerNorm(width),
            nn.GELU(),
        )

        # Fuses the concatenated [input features | time features] vector.
        self.joint_embedding = nn.Sequential(
            nn.Linear(wide, wide),
            nn.LayerNorm(wide),
            nn.GELU(),
            nn.Linear(wide, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Linear(width, width),
        )

        self.residual_blocks = nn.ModuleList(
            ResidualBlock(width, dropout) for _ in range(num_layers)
        )

        self.output_layer = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, d_ext),
        )

    def forward(self, time: torch.Tensor, x_batch: torch.Tensor) -> torch.Tensor:
        """Return the std-scaled score estimate, shape ``(batch_size, d_ext)``.

        NOTE(review): ``time`` appears to be expected as ``(batch_size, 1)``,
        since it is squeezed on the last axis before ``marginal_prob_std`` and
        the result is re-expanded with ``[:, None]``. Confirm with callers.
        """
        x_features = self.input_layer(x_batch)
        t_features = self.act(self.time_embedding(time))
        hidden = self.joint_embedding(torch.cat((x_features, t_features), dim=-1))

        for block in self.residual_blocks:
            hidden = block(hidden)

        raw_output = self.output_layer(hidden)
        sigma_t = self.diffusion_process.marginal_prob_std(time.squeeze(-1))
        return raw_output / sigma_t[:, None]


# class BottleneckScoreNetwork(nn.Module):
#     def __init__(
#         self,
#         d_ext: int,
#         hidden_size: int,
#         num_layers: int,
#         dropout: float,
#         diffusion_process,
#     ):
#         super(BottleneckScoreNetwork, self).__init__()
#         self.diffusion_process = diffusion_process

#         # Time embedding
#         self.time_embedding = GaussianFourierProjection(embed_dim=hidden_size)

#         self.activation = nn.GELU()

#         # Input projection to bottleneck dimension (using LowRankLinear)
#         rank = min(d_ext, hidden_size) // 4  # Adjust rank as needed
#         self.input_layer = LowRankLinear(d_ext, hidden_size, rank=rank)

#         # Hidden layers in bottleneck dimension
#         self.hidden_layers = nn.ModuleList()
#         for _ in range(num_layers):
#             self.hidden_layers.append(
#                 nn.Sequential(
#                     nn.LayerNorm(hidden_size),
#                     nn.GELU(),
#                     nn.Linear(hidden_size, hidden_size),
#                     nn.Dropout(dropout),
#                 )
#             )

#         self.hidden_norm = nn.LayerNorm(hidden_size)

#         # Output projection back to d_ext (using LowRankLinear)
#         self.output_layer = LowRankLinear(hidden_size, d_ext, rank=rank)


#     def forward(self, time: torch.Tensor, x_batch: torch.Tensor) -> torch.Tensor:
#         # Embed time
#         time_emb = self.activation(self.time_embedding(time))

#         # Project input to bottleneck dimension
#         x_proj = self.input_layer(x_batch)

#         # Combine with time embedding
#         h_batch = x_proj + time_emb

#         # Pass through hidden layers with residual connections
#         for layer in self.hidden_layers:
#             residual = h_batch
#             h_batch = layer(h_batch)
#             h_batch = h_batch + residual

#         output = self.output_layer(h_batch)

#         # Scale output by the diffusion process's standard deviation
#         std = self.diffusion_process.marginal_prob_std(time.squeeze(-1))[:, None]
#         scaled_output = output / std

#         return scaled_output


# class BottleneckScoreNetwork(nn.Module):
#     def __init__(
#         self,
#         d_ext: int,
#         hidden_size: int,
#         num_layers: int,
#         dropout: float,
#         diffusion_process,
#     ):
#         super(BottleneckScoreNetwork, self).__init__()
#         self.diffusion_process = diffusion_process

#         # Time embedding
#         self.time_embedding = GaussianFourierProjection(embed_dim=hidden_size)

#         self.activation = nn.GELU()

#         # Input projection to bottleneck dimension
#         self.input_layer = nn.Linear(d_ext, hidden_size)

#         # Hidden layers in bottleneck dimension
#         self.hidden_layers = nn.ModuleList()
#         for _ in range(num_layers):
#             self.hidden_layers.append(
#                 nn.Sequential(
#                     nn.LayerNorm(hidden_size),
#                     nn.GELU(),
#                     nn.Linear(hidden_size, hidden_size),
#                     nn.Dropout(dropout),
#                 )
#             )

#         self.hidden_norm = nn.LayerNorm(hidden_size)

#         # Output projection back to d_ext
#         self.output_layer = nn.Linear(hidden_size, d_ext)

#     def forward(self, time: torch.Tensor, x_batch: torch.Tensor) -> torch.Tensor:
#         # Embed time
#         time_emb = self.activation(self.time_embedding(time))

#         # Project input to bottleneck dimension
#         x_proj = self.input_layer(x_batch)

#         # Combine with time embedding
#         h_batch = x_proj + time_emb

#         # Pass through hidden layers with residual connections
#         for layer in self.hidden_layers:
#             residual = h_batch
#             h_batch = layer(h_batch)
#             h_batch = h_batch + residual

#         output = self.output_layer(h_batch)

#         # Scale output by the diffusion process's standard deviation
#         std = self.diffusion_process.marginal_prob_std(time.squeeze(-1))[:, None]
#         scaled_output = output / std

#         return scaled_output


class BottleneckScoreNetwork(nn.Module):
    """Residual-MLP score network with additive time conditioning.

    Pipeline:

    1. ``x`` is linearly projected to ``hidden_size`` features.
    2. A Gaussian Fourier embedding of ``t`` (passed through GELU) is added.
    3. The sum passes through ``num_layers`` residual layers of the form
       ``LayerNorm -> LeakyReLU -> Linear -> Dropout``.
    4. The result is projected back to ``d_ext`` and divided by
       ``diffusion_process.marginal_prob_std(t)``.

    Despite the class name, every hidden layer is full width
    (``hidden_size -> hidden_size``); no low-rank bottleneck is applied.

    Args:
        d_ext: Input/output feature dimension.
        hidden_size: Width of the hidden layers and of the time embedding.
            Must be even so the Fourier embedding matches the input projection.
        num_layers: Number of residual hidden layers.
        dropout: Dropout probability inside each hidden layer.
        diffusion_process: Object whose ``marginal_prob_std`` rescales the
            network output.
        factor: Currently unused; accepted so existing configs and call sites
            that pass it keep working.

    Raises:
        ValueError: If ``hidden_size`` is odd.
    """

    def __init__(
        self,
        d_ext: int,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        diffusion_process,
        factor: int = 5,
    ):
        super(BottleneckScoreNetwork, self).__init__()
        if hidden_size % 2:
            # GaussianFourierProjection yields 2 * (hidden_size // 2) features,
            # which could not be added to the hidden_size-wide input projection.
            raise ValueError(f"hidden_size must be even, got {hidden_size}")

        self.diffusion_process = diffusion_process

        # Time embedding
        self.time_embedding = GaussianFourierProjection(embed_dim=hidden_size)

        self.activation = nn.GELU()

        # Input projection to hidden size
        self.input_layer = nn.Linear(d_ext, hidden_size)

        # Full-width residual hidden layers (no bottleneck; see class docstring).
        self.hidden_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.hidden_layers.append(
                nn.Sequential(
                    nn.LayerNorm(hidden_size),
                    nn.LeakyReLU(),
                    nn.Linear(hidden_size, hidden_size),
                    nn.Dropout(dropout),
                )
            )

        # NOTE(review): registered but never applied in forward(); kept so the
        # state_dict keys of existing checkpoints still match. Possibly meant
        # as a final norm before output_layer, but applying it now would change
        # the outputs of trained models.
        self.hidden_norm = nn.LayerNorm(hidden_size)

        # Output projection back to d_ext
        self.output_layer = nn.Linear(hidden_size, d_ext)

    def forward(self, time: torch.Tensor, x_batch: torch.Tensor) -> torch.Tensor:
        """Return the std-scaled score estimate, shape ``(batch, d_ext)``.

        NOTE(review): ``time`` appears to be expected as ``(batch, 1)``,
        since it is squeezed on the last axis before ``marginal_prob_std`` and
        the result is re-expanded with ``[:, None]``. Confirm with callers.
        """
        # Embed time
        time_emb = self.activation(self.time_embedding(time))

        # Project input to hidden size and add the time conditioning.
        h_batch = self.input_layer(x_batch) + time_emb

        # Hidden layers with residual (skip) connections.
        for layer in self.hidden_layers:
            h_batch = layer(h_batch) + h_batch

        output = self.output_layer(h_batch)

        # Scale output by the diffusion process's standard deviation
        std = self.diffusion_process.marginal_prob_std(time.squeeze(-1))[:, None]
        return output / std


def get_score_network(name, **kwargs):
    """Instantiate a score network by name.

    Args:
        name: ``"residual"`` for ``ResidualScore`` or ``"bottleneck"`` for
            ``BottleneckScoreNetwork``.
        **kwargs: Forwarded unchanged to the selected class's constructor.

    Returns:
        The constructed network module.

    Raises:
        ValueError: If ``name`` is not a known network name. Previously an
            unknown name silently returned ``None``.
    """
    if name == "residual":
        return ResidualScore(**kwargs)
    if name == "bottleneck":
        return BottleneckScoreNetwork(**kwargs)
    raise ValueError(
        f"Unknown score network {name!r}; expected 'residual' or 'bottleneck'."
    )


def generate_model_name(config):
    """Build a descriptive, timestamped name for a trained score model.

    The name encodes dataset, score-model, projection, diffusion, optimizer
    and training settings read from ``config``, each field separated by
    ``_``. It ends with the current local time formatted as
    ``%Y%m%d_%H%M%S``.

    Args:
        config: Object exposing ``dataset``, ``score_model``, ``diffusion``,
            ``optimizer``, ``training`` and ``projection`` attribute groups.

    Returns:
        A string such as ``score_residual_sphere_extr10_intr2_n100_r1.0_``
        ``linear_layers3_hidden64_do0.1_sigma25.0_lr0.001_epoch5_``
        ``20240101_120000``.
    """
    dataset = config.dataset
    score_model = config.score_model

    # Join every field with "_" so adjacent fields cannot run together
    # (e.g. radius and projection type, or dropout and sigma).
    base_name = "_".join(
        [
            "score",
            f"{score_model.name}",
            f"{dataset.shape}",
            f"extr{dataset.d_ext}",
            f"intr{dataset.d_intr}",
            f"n{dataset.num_samples}",
            f"r{dataset.radius}",
            f"{config.projection.type}",
            f"layers{score_model.num_layers}",
            f"hidden{score_model.hidden_size}",
            f"do{score_model.dropout}",
            f"sigma{config.diffusion.sigma}",
            f"lr{config.optimizer.learn_rate}",
            f"epoch{config.training.n_epochs}",
        ]
    )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{base_name}_{timestamp}"
